#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict


from __future__ import annotations

import asyncio
import time
import unittest
from typing import Sequence

from derived.thrift_clients import DerivedTestingService
from derived.thrift_services import DerivedTestingServiceInterface
from folly.iobuf import IOBuf

from stack_args.thrift_clients import StackService
from stack_args.thrift_services import StackServiceInterface
from stack_args.thrift_types import simple
from testing.thrift_clients import TestingService
from testing.thrift_services import TestingServiceInterface
from testing.thrift_types import Color, easy, SimpleError
from thrift.lib.python.test.event_handlers.helper import ThrowHelper, ThrowHelperHandler
from thrift.lib.python.test.test_server import TestServer
from thrift.py3.server import get_context, SocketAddress
from thrift.python.client import get_client
from thrift.python.common import Priority, RpcOptions
from thrift.python.exceptions import ApplicationError
from thrift.python.server import ServiceInterface


class Handler(TestingServiceInterface):
    """
    TestingService implementation backing the client/server tests.

    Several methods read the current request via get_context();
    test_get_context expects a LookupError when getName is called
    outside of a running server.
    """

    async def invert(self, value: bool) -> bool:
        """
        Return the negation of value.

        If the request carries a "from client" header, a "from server"
        header is set on the request context before returning.
        """
        request_ctx = get_context()
        client_sent_header = "from client" in request_ctx.read_headers
        if client_sent_header:
            request_ctx.set_header("from server", "with love")
        return not value

    async def getName(self) -> str:
        """Set the "contextvar" header on the request context, return "Testing"."""
        get_context().set_header("contextvar", "true")
        return "Testing"

    async def getMethodName(self) -> str:
        """Return the method name recorded on the request context."""
        return get_context().method_name

    async def getRequestId(self) -> str:
        """Return the request id recorded on the request context."""
        return get_context().request_id

    async def getRequestTimeout(self) -> float:
        """Return the request timeout recorded on the request context."""
        return get_context().request_timeout

    async def shutdown(self) -> None:
        """Do nothing."""
        return None

    async def complex_action(
        self, first: str, second: str, third: int, fourth: str
    ) -> int:
        """Return the third argument; the other arguments are ignored."""
        return third

    async def takes_a_list(self, ints: Sequence[int]) -> None:
        """Always raise SimpleError with color red, whatever ints contains."""
        error = SimpleError(color=Color.red)
        raise error

    async def take_it_easy(self, how: int, what: easy) -> None:
        """Accept the arguments and do nothing."""
        return None

    async def pick_a_color(self, color: Color) -> None:
        """Accept the color and do nothing."""
        return None

    async def int_sizes(self, one: int, two: int, three: int, four: int) -> None:
        """Accept the four integers and do nothing."""
        return None

    async def hard_error(self, valid: bool) -> None:
        """Accept the flag and do nothing (this handler never raises here)."""
        return None

    async def renamed_func(self, ret: bool) -> bool:
        """Return ret unchanged."""
        return ret

    async def getPriority(self) -> int:
        """Return the integer value of the request context's priority."""
        request_priority = get_context().priority
        return request_priority.value


class DerivedHandler(Handler, DerivedTestingServiceInterface):
    """DerivedTestingService handler that inherits the Handler methods."""

    async def getName(self) -> str:
        """Return "DerivedTesting" without touching the request context."""
        service_name = "DerivedTesting"
        return service_name

    async def derived_pick_a_color(self, color: Color) -> Color:
        """Return the color that was passed in."""
        picked = color
        return picked


def local_server(handler: ServiceInterface | None = None) -> TestServer:
    """
    Build a TestServer bound to the "::1" address.

    A fresh Handler instance is used when no handler is supplied.
    """
    service = Handler() if handler is None else handler
    return TestServer(handler=service, ip="::1")


def default_server() -> TestServer:
    """
    Build a TestServer around a fresh Handler, passing no ip.

    Per the original note, the port is set to 0 in this case; that comes
    from TestServer's defaults, which are not visible in this file.
    """
    service = Handler()
    return TestServer(handler=service)


class ClientServerTests(unittest.IsolatedAsyncioTestCase):
    """
    These are tests where a client and server talk to each other
    """

    async def test_get_context(self) -> None:
        """
        Handlers can read request metadata through get_context() while a
        request is being served, and get a LookupError otherwise.
        """
        async with local_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                options = RpcOptions()
                options.timeout = 100.0

                self.assertEqual("Testing", await client.getName(rpc_options=options))
                # Handler.getName sets this header on the request context.
                self.assertEqual("true", options.read_headers["contextvar"])
                self.assertEqual(
                    "getMethodName",
                    await client.getMethodName(),
                )
                # requestId is a 16 char wide hex string
                self.assertEqual(
                    len(await client.getRequestId()),
                    16,
                )
                self.assertEqual(
                    100.0,
                    await client.getRequestTimeout(rpc_options=options),
                )

        handler = Handler()  # so we can call it outside the thrift server
        with self.assertRaises(LookupError):
            await handler.getName()

    async def test_rpc_headers(self) -> None:
        """
        A header set on RpcOptions reaches the handler, and the header the
        handler sets in reply is readable from the same RpcOptions.
        """
        async with local_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                options = RpcOptions()
                options.set_header("from client", "with love")
                self.assertFalse(await client.invert(True, rpc_options=options))
                self.assertIn("from server", options.read_headers)

    async def test_server_localhost(self) -> None:
        """
        Plain calls return their results and a SimpleError raised by the
        handler is re-raised on the client.
        """
        async with local_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                self.assertTrue(await client.invert(False))
                self.assertFalse(await client.invert(True))
                # TODO (ffrancet): after RPC headers are supported, check uex and uexw
                with self.assertRaises(SimpleError):
                    await client.takes_a_list([])

    async def test_no_client_aexit(self) -> None:
        """
        This covers a client that is entered and used but never exited
        """
        async with default_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            client = get_client(TestingService, host=ip, port=port)
            await client.__aenter__()
            self.assertTrue(await client.invert(False))
            self.assertFalse(await client.invert(True))
            # __aexit__ is deliberately never called.
            # If we do not abort here then good

    async def test_client_aexit_no_await(self) -> None:
        """
        This actually handles the case if __aexit__ is not awaited
        """

        async with default_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            client = get_client(TestingService, host=ip, port=port)
            await client.__aenter__()
            self.assertTrue(await client.invert(False))
            self.assertFalse(await client.invert(True))
            # pyre-fixme[1001]: Async expression is not awaited.
            _ = client.__aexit__(None, None, None)
            del client  # If we do not abort here then good

    async def test_no_client_no_aenter(self) -> None:
        """
        This covers if aenter was canceled since those two are the same really
        """

        async with default_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            # The client is created and dropped without ever being entered.
            get_client(TestingService, host=ip, port=port)
            # If we do not abort here then good

    async def test_derived_service(self) -> None:
        """
        This tests calling methods from a derived service
        """

        async with local_server(handler=DerivedHandler()) as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(
                DerivedTestingService,
                host=ip,
                port=port,
            ) as client:
                self.assertEqual(await client.getName(), "DerivedTesting")
                self.assertEqual(
                    await client.derived_pick_a_color(Color.red), Color.red
                )

    async def test_renamed_func(self) -> None:
        """
        This tests that renamed_func is callable and echoes its argument
        """
        async with local_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                self.assertEqual(True, await client.renamed_func(True))

    async def test_queue_timeout(self) -> None:
        """
        This tests whether queue timeout functions properly.
        """

        class SlowDerivedHandler(Handler, DerivedTestingServiceInterface):
            async def getName(self) -> str:
                # NOTE(review): presumably a blocking sleep on purpose, so
                # that concurrent requests wait longer than the queue
                # timeout configured below -- confirm before changing it
                # to asyncio.sleep.
                time.sleep(1)
                return "SlowDerivedTesting"

            async def derived_pick_a_color(self, color: Color) -> Color:
                return color

        testing = TestServer(handler=SlowDerivedHandler())
        testing.server.set_queue_timeout(0.01)
        concurrent_clients = 5

        async def client_call(sa: SocketAddress) -> str:
            """
            Call getName once; return its result, "Queue Timeout" when the
            ApplicationError mentions it, or "" for any other
            ApplicationError.
            """
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(DerivedTestingService, host=ip, port=port) as client:
                try:
                    return await client.getName()
                except ApplicationError as err:
                    return "Queue Timeout" if "Queue Timeout" in str(err) else ""

        async with testing as sa:
            results = await asyncio.gather(
                *(client_call(sa) for _ in range(concurrent_clients))
            )
            # At least one of the concurrent calls must have timed out.
            self.assertIn("Queue Timeout", results)

    async def test_cancelled_task(self) -> None:
        """
        This tests whether cancelled tasks are handled properly.
        """
        cancelledMessage: str = "I have been cancelled"

        class CancelHandler(Handler):
            async def getName(self) -> str:
                raise asyncio.CancelledError(
                    cancelledMessage
                )  # Pretend that this is some await call that gets cancelled

        async with local_server(handler=CancelHandler()) as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                with self.assertRaises(ApplicationError) as ex:
                    await client.getName()
                self.assertEqual(
                    ex.exception.message,
                    f"Application was cancelled on the server with message: {cancelledMessage}",
                )

    async def test_unexpected_error(self) -> None:
        """
        This tests whether unexpected errors are handled properly.
        """
        errMessage: str = "I am an error"

        class ErrorHandler(TestingServiceInterface):
            async def getName(self) -> str:
                raise Exception(errMessage)

        async with local_server(handler=ErrorHandler()) as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                with self.assertRaises(ApplicationError) as ex:
                    await client.getName()
                # The client sees the repr of the exception raised above.
                self.assertEqual(
                    ex.exception.message,
                    f"Exception('{errMessage}')",
                )

    async def test_request_with_default_rpc_options(self) -> None:
        """
        Without RpcOptions the handler reports a timeout of 0.0 and a
        priority of Priority.N_PRIORITIES.
        """
        async with local_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                timeout = await client.getRequestTimeout()
                self.assertEqual(timeout, 0.0)
                priority = await client.getPriority()
                self.assertEqual(Priority(priority), Priority.N_PRIORITIES)

    async def test_request_with_specified_rpc_options(self) -> None:
        """
        The timeout and priority set on RpcOptions are the ones the handler
        reads from its request context.
        """
        async with local_server() as sa:
            ip, port = sa.ip, sa.port
            assert ip and port
            async with get_client(TestingService, host=ip, port=port) as client:
                options = RpcOptions()
                options.timeout = 15.0
                options.priority = Priority.BEST_EFFORT
                timeout = await client.getRequestTimeout(rpc_options=options)
                self.assertEqual(timeout, 15.0)
                priority = await client.getPriority(rpc_options=options)
                self.assertEqual(Priority(priority), Priority.BEST_EFFORT)

    async def test_client_event_handler_throw(self) -> None:
        """
        For every ThrowHelperHandler case, a call made inside ThrowHelper
        raises an ApplicationError whose text matches the case's value.
        """
        for handler in ThrowHelperHandler:
            # One sub-test per case, so a failure names the case and the
            # remaining cases still run.
            with self.subTest(handler=handler):
                async with local_server() as sa:
                    ip, port = sa.ip, sa.port
                    assert ip and port
                    async with get_client(
                        TestingService, host=ip, port=port
                    ) as client:
                        with ThrowHelper(handler):
                            with self.assertRaisesRegex(
                                ApplicationError, handler.value
                            ):
                                await client.complex_action("1", "2", 3, "4")


class StackHandler(StackServiceInterface):
    """StackService implementation with fixed replies and argument checks."""

    async def add_to(self, lst: Sequence[int], value: int) -> Sequence[int]:
        """Return a new list holding each element of lst increased by value."""
        return [element + value for element in lst]

    async def get_simple(self) -> simple:
        """Return a simple struct whose val is 66."""
        reply = simple(val=66)
        return reply

    async def get_simple_no_sa(self) -> simple:
        """Return a simple struct whose val is 88."""
        reply = simple(val=88)
        return reply

    async def take_simple(self, smpl: simple) -> None:
        """Raise Exception("WRONG") unless smpl.val equals 10."""
        if smpl.val == 10:
            return
        raise Exception("WRONG")

    async def get_iobuf(self) -> IOBuf:
        """Return an IOBuf built from b"abc"."""
        payload = b"abc"
        return IOBuf(payload)

    async def take_iobuf(self, val: IOBuf) -> None:
        """Raise Exception("WRONG") unless the joined chunks equal b"cba"."""
        received = b"".join(val)
        if received == b"cba":
            return
        raise Exception("WRONG")

    # currently unsupported by cpp backend:
    # async def get_iobuf_ptr(self) -> IOBuf:
    #     return IOBuf(b'xyz')

    async def take_iobuf_ptr(self, val: IOBuf) -> None:
        """Raise Exception("WRONG") unless the joined chunks equal b"zyx"."""
        received = b"".join(val)
        if received == b"zyx":
            return
        raise Exception("WRONG")


class ClientStackServerTests(unittest.IsolatedAsyncioTestCase):
    """
    These are tests where a client and server(stack_arguments) talk to each other
    """

    async def test_server_localhost(self) -> None:
        """Round-trip StackService calls through a local StackHandler server."""
        async with local_server(handler=StackHandler()) as address:
            host, port_number = address.ip, address.port
            assert host and port_number
            async with get_client(
                StackService, host=host, port=port_number
            ) as stack:
                shifted = await stack.add_to(lst=(1, 2, 3, 4), value=2)
                self.assertEqual((3, 4, 5, 6), shifted)

                first_simple = await stack.get_simple()
                self.assertEqual(66, first_simple.val)

                second_simple = await stack.get_simple_no_sa()
                self.assertEqual(second_simple.val, 88)

                # StackHandler raises unless val is exactly 10.
                await stack.take_simple(simple(val=10))

                received_buf = await stack.get_iobuf()
                self.assertEqual(b"abc", b"".join(received_buf))

                await stack.take_iobuf(IOBuf(b"cba"))
                # currently unsupported by cpp backend:
                # self.assertEqual(b'xyz', (await client.get_iobuf_ptr()))
                await stack.take_iobuf_ptr(IOBuf(b"zyx"))
